Shape

形状推断函数(Shape Inference Function)。该函数根据输入张量的形状和算子参数,推断输出张量的形状。该函数不区分数据类型,只处理张量的形状信息。

如果所有输出张量都是常量(ConstTensor 或 ConstScalar),则直接返回,不进行形状推断。否则,根据算子类型调用相应的形状推断函数。

支持的算子类型:
  • Arithmetic_InferShape - 算术运算的形状推断

  • Common_InferShape - 通用算子的形状推断

  • Softmax_InferShape - Softmax 算子的形状推断

  • MaxMinGrad_InferShape - MaxMin 梯度算子的形状推断

  • Dropout_InferShape - Dropout 算子的形状推断

  • DynamicQuant_InferShape - 动态量化算子的形状推断

  • Fft_InferShape - FFT 算子的形状推断

  • Flatten_InferShape - Flatten 算子的形状推断

  • LayerNorm_InferShape - LayerNorm 算子的形状推断

  • LogSoftmax_InferShape - LogSoftmax 算子的形状推断

输入:
  • inputs - 输入张量数组(TensorC** 类型)。

  • inputs_size - 输入张量的数量。

  • outputs - 输出张量数组(TensorC** 类型)。

  • outputs_size - 输出张量的数量。

  • param - 算子参数(OpParameter* 类型),包含算子类型和其他参数信息。

输出:
  • outputs - 输出张量数组,其中的形状信息会被更新。

支持平台:

FT78NE MT7004

备注

  • 该函数不区分数据类型,适用于所有数据类型

  • 函数会自动检查输出是否为常量,如果是常量则跳过形状推断

共享存储/私有存储版本:

void shape(TensorC **inputs, int inputs_size, TensorC **outputs, int outputs_size, OpParameter *param)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <shape.h>
 4
 5int main(int argc, char* argv[]) {
 6    TensorC** input_tensors_ptrs = (TensorC**)0x10010000;
 7    TensorC** output_tensors_ptrs = (TensorC**)0x10011000;
 8
 9    TensorC input0;
10    TensorC input1;
11    TensorC output;
12
13    int input0_shape[4] = {1,2,3,4};
14    int input1_shape[4] = {1,3,4};
15    int output_shape[4]; //不用初始化
16    memcpy(input0.shape_, input0_shape, 4 * sizeof(int));
17    input0.shape_size_ = 4;
18    memcpy(input1.shape_, input1_shape, 4 * sizeof(int));
19    input1.shape_size_ = 3;
20    input0.data_type_ = kNumberTypeFloat32;
21    input1.data_type_ = kNumberTypeFloat32;
22    input0.format_ = Format_NCHW;
23    input1.format_ = Format_NCHW;
24
25    input_tensors_ptrs[0] = &input0;
26    input_tensors_ptrs[1] = &input1;
27    output_tensors_ptrs[0] = &output;
28
29    ArithmeticParameter param;
30    param.op_parameter_.type_ = Arithmetic_InferShape;
31
32    shape(input_tensors_ptrs, 2, output_tensors_ptrs, 1, (OpParameter*)&param);
33    return 0;
34}